import os
import os.path as osp

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os  # needed for os.environ; `import os.path as osp` binds only `osp`

ROOT = osp.dirname(osp.abspath(__file__))

# Default SASS targets (Pascal sm_61 through Hopper sm_90).
# NOTE: sm_89 / sm_90 require nvcc from CUDA 11.8 or newer.
_DEFAULT_GENCODE_FLAGS = [
    '-gencode=arch=compute_61,code=sm_61',
    '-gencode=arch=compute_70,code=sm_70',
    '-gencode=arch=compute_75,code=sm_75',
    '-gencode=arch=compute_80,code=sm_80',
    '-gencode=arch=compute_86,code=sm_86',
    '-gencode=arch=compute_89,code=sm_89',
    '-gencode=arch=compute_90,code=sm_90',
    # Also embed compute_90 PTX so GPUs newer than sm_90 can JIT-compile the
    # kernels instead of failing with "no kernel image is available".
    '-gencode=arch=compute_90,code=compute_90',
]


def _nvcc_arch_flags():
    """Return the nvcc ``-gencode`` flags used to build ``droid_backends``.

    PyTorch's ``BuildExtension`` adds no arch flags of its own once any arch
    flag is present in the user-supplied nvcc args. A hard-coded list would
    therefore silently override ``TORCH_CUDA_ARCH_LIST``. When that variable
    is set, return no flags so PyTorch derives the targets from it. Otherwise
    fall back to the default list above.
    """
    if os.environ.get('TORCH_CUDA_ARCH_LIST'):
        return []
    return list(_DEFAULT_GENCODE_FLAGS)


setup(
    name='droid_backends',
    ext_modules=[
        CUDAExtension(
            'droid_backends',
            # Absolute path so the header search does not depend on the cwd.
            include_dirs=[osp.join(ROOT, 'thirdparty/eigen')],
            sources=[
                'src/droid.cpp',
                'src/droid_kernels.cu',
                'src/correlation_kernels.cu',
                'src/altcorr_kernel.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                # To silence noisy warnings, the following can be appended:
                #   '-Xcompiler', '-Wno-deprecated-declarations',
                #   '-Xcudafe', '--diag_suppress=177,20014,20011,20236',
                'nvcc': ['-O3', *_nvcc_arch_flags()],
            },
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)

# setup(
#     name='lietorch',
#     version='0.3',
#     description='Lie Groups for PyTorch',
#     packages=['lietorch'],
#     package_dir={'': 'thirdparty/lietorch'},
#     ext_modules=[
#         CUDAExtension('lietorch_backends', 
#             include_dirs=[
#                 osp.join(ROOT, 'thirdparty/lietorch/lietorch/include'), 
#                 osp.join(ROOT, 'thirdparty/eigen')],
#             sources=[
#                 'thirdparty/lietorch/lietorch/src/lietorch.cpp', 
#                 'thirdparty/lietorch/lietorch/src/lietorch_gpu.cu',
#                 'thirdparty/lietorch/lietorch/src/lietorch_cpu.cpp'],
#             extra_compile_args={
#                 'cxx': ['-O2'],
#                 'nvcc': [
#                     '-O2',
#                     # '-Xcompiler', '-Wno-deprecated-declarations',
#                     # '-Xcudafe', '--diag_suppress=177,20014,20011,20236',
#                     '-gencode=arch=compute_61,code=sm_61', 
#                     '-gencode=arch=compute_70,code=sm_70', 
#                     '-gencode=arch=compute_75,code=sm_75',
#                     '-gencode=arch=compute_80,code=sm_80',
#                     '-gencode=arch=compute_86,code=sm_86',                 
#                     '-gencode=arch=compute_89,code=sm_89',                 
#                 ]
#             }),
#     ],
#     cmdclass={ 'build_ext' : BuildExtension }
# )